/*
 * Copyright 2020 eBlocker Open Source UG (haftungsbeschraenkt)
 *
 * Licensed under the EUPL, Version 1.2 or - as soon they will be
 * approved by the European Commission - subsequent versions of the EUPL
 * (the "License"); You may not use this work except in compliance with
 * the License. You may obtain a copy of the License at:
 *
 *   https://joinup.ec.europa.eu/page/eupl-text-11-12
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" basis,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */
package org.eblocker.lists.malware;

import org.eblocker.server.common.malware.MalwareEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class MalwarePatrolProvider implements MalwareProvider {

    private static final Logger log = LoggerFactory.getLogger(MalwarePatrolProvider.class);

    private static final Pattern hostPortPattern = Pattern.compile("^([^:/]*):(\\d+)");
    private static final Pattern IP_PATTERN = Pattern.compile("^\\d+\\.\\d+\\.\\d+\\.\\d+$");

    private static final String DNS_ERROR_RATIO = "provider.malwarepatrol.dns.errorRatio";
    private static final String DNS_THREADS = "provider.malwarepatrol.dns.threads";
    private static final String DNS_TIMEOUT_SECONDS = "provider.malwarepatrol.dns.timeoutSeconds";

    private final MalwarePatrolDownloader downloader;
    private final DnsResolver resolver;

    private final float dnsErrorRatio;
    private final int dnsThreads;
    private final int dnsTimeoutSeconds;

    public MalwarePatrolProvider(Properties properties, MalwarePatrolDownloader downloader, DnsResolver resolver) {
        this.downloader = downloader;
        this.resolver = resolver;

        dnsErrorRatio = Float.parseFloat(properties.getProperty(DNS_ERROR_RATIO));
        dnsThreads = Integer.parseInt(properties.getProperty(DNS_THREADS));
        dnsTimeoutSeconds = Integer.parseInt(properties.getProperty(DNS_TIMEOUT_SECONDS));
    }

    @Override
    public MalwareEntries getMalwareEntries() throws MalwareListException {
        try {
            List<MalwareEntry> filteredUrls = downloader.retrieveEntries();
            Map<String, Set<Integer>> filteredIpPorts = resolveHostsWithNonStandPorts(filteredUrls);
            return new MalwareEntries(filteredUrls, filteredIpPorts);
        } catch (IOException e) {
            throw new MalwareListException("i/o error", e);
        }
    }

    private Map<String, Set<Integer>> resolveHostsWithNonStandPorts(List<MalwareEntry> entries) throws NameResolutionException {
        Map<Boolean, List<HostPortTuple>> hosts = entries.stream()
                .map(e -> hostPortPattern.matcher(e.getUrl()))
                .filter(Matcher::find)
                .filter(m -> !"80".equals(m.group(2)) && !"443".equals(m.group(2)))
                .map(m -> new HostPortTuple(m.group(1), Integer.parseInt(m.group(2))))
                .distinct()
                .collect(Collectors.groupingBy(t -> IP_PATTERN.matcher(t.host).matches()));
        hosts.putIfAbsent(false, Collections.emptyList());
        hosts.putIfAbsent(true, Collections.emptyList());

        log.info("{} out of {} entries use non-standard port", hosts.get(false).size() + hosts.get(true).size(), entries.size());
        log.info("{} entries needs to be resolved", hosts.get(false).size());

        AtomicInteger errors = new AtomicInteger();
        ConcurrentLinkedQueue<HostPortTuple> resolved = new ConcurrentLinkedQueue<>();
        ExecutorService executorService = createExecutorService();
        hosts.get(false).stream().forEach(t -> executorService.submit(() -> resolve(t, resolved, errors)));

        executorService.shutdown();
        try {
            if (!executorService.awaitTermination(dnsTimeoutSeconds, TimeUnit.SECONDS)) {
                executorService.shutdownNow();
                throw new NameResolutionException("failed to resolve hosts in " + dnsTimeoutSeconds + " seconds");
            }
        } catch (InterruptedException e) {
            executorService.shutdownNow();
            Thread.currentThread().interrupt();
            throw new NameResolutionException("interrupted while waiting for host resolution", e);
        }

        if (errors.get() > 0) {
            float ratio = (float) errors.get() / hosts.get(false).size();
            log.warn("{} of {} hosts did not resolve ({})",errors, hosts.get(false).size(), ratio);

            if (ratio >= dnsErrorRatio) {
                throw new NameResolutionException("failure ratio too high: " + ratio);
            }
        }
        log.info("{} hosts resolved to {} ip-addresses", hosts.get(false).size(), resolved.size());

        return Stream.concat(hosts.get(true).stream(), resolved.stream())
                .collect(Collectors.groupingBy(t -> t.host))
                .entrySet().stream()
                .collect(Collectors.toMap(
                        Map.Entry::getKey,
                        t -> t.getValue().stream()
                                .map(e -> e.port)
                                .collect(Collectors.toSet())));
    }

    private ExecutorService createExecutorService() {
        // custom thread-factory is used to allow vm-exit before all threads are done. Otherwise exit on timeout
        // will be delayed until currently running threads are done.
        final AtomicInteger i = new AtomicInteger();
        return Executors.newFixedThreadPool(dnsThreads, r -> {
            Thread thread = new Thread(r, "dns-resolver-" + i.incrementAndGet());
            thread.setDaemon(true);
            return thread;
        });
    }

    private void resolve(HostPortTuple t, Collection<HostPortTuple> results, AtomicInteger errors) {
        try {
            log.debug("resolving {}", t);

            for(InetAddress inetAddress : resolver.resolve(t.host)) {
                if (inetAddress instanceof Inet4Address) {
                    log.debug("{}: {}", t, inetAddress.getHostAddress());
                    results.add(new HostPortTuple(inetAddress.getHostAddress(), t.port));
                } else {
                    log.debug("{}: {} (ignored)", t, inetAddress.getHostAddress());
                }
            }
        } catch (UnknownHostException e) {
            log.warn("failed to resolve {}", t.host, e);
            errors.incrementAndGet();
        }
    }

    static class NameResolutionException extends MalwareListException {
        NameResolutionException(String message) {
            super(message);
        }

        NameResolutionException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    private class HostPortTuple {
        String host;
        Integer port;

        HostPortTuple(String host, Integer port) {
            this.host = host;
            this.port = port;
        }

        @Override
        public String toString() {
            return host + ":" + port;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }

            if (o == null || getClass() != o.getClass()) {
                return false;
            }

            HostPortTuple that = (HostPortTuple) o;
            if (!host.equals(that.host)) {
                return false;
            }
            return port.equals(that.port);
        }

        @Override
        public int hashCode() {
            int result = host.hashCode();
            result = 31 * result + port.hashCode();
            return result;
        }
    }
}
